import torch
import torchvision.models as models
import torch.nn as nn
from util import Parameter

def prepare50(
    prepared_path: str,
    *,
    in_channels: int = 3,
    save: bool = False,
) -> nn.Module:
    """Build a pretrained torchvision ResNet-50 with a new classifier head.

    The model is created with ``models.ResNet50_Weights.DEFAULT`` (torchvision
    may download the weights on first use). Its final fully connected layer is
    replaced by a freshly initialised ``nn.Linear`` that maps the backbone's
    feature width (read from ``model.fc.in_features`` rather than hard-coded)
    to ``Parameter.NUM_CLASSES`` outputs. The resulting module is printed.

    Args:
        prepared_path: Destination file for ``torch.save``. It is only written
            when ``save`` is true.
        in_channels: Number of input image channels. When this is not 3,
            ``conv1`` is replaced by a freshly initialised 7x7, stride-2 conv
            with this many input channels (e.g. 1). This discards the
            pretrained weights of the stem convolution.
        save: If true, pickle the whole module to ``prepared_path`` with
            ``torch.save``, creating missing parent directories first.

    Returns:
        The modified model.
    """
    from pathlib import Path

    model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)

    if in_channels != 3:
        # Same geometry as the stock ResNet stem; only the input channel
        # count changes, so the new layer starts untrained.
        model.conv1 = nn.Conv2d(
            in_channels,
            model.conv1.out_channels,
            kernel_size=(7, 7),
            stride=(2, 2),
            padding=(3, 3),
            bias=False,
        )

    # Derive the head's input width from the existing layer instead of
    # assuming 2048, so the replacement always matches the backbone.
    model.fc = nn.Linear(
        in_features=model.fc.in_features,
        out_features=Parameter.NUM_CLASSES,
        bias=True,
    )

    if save:
        destination = Path(prepared_path)
        destination.parent.mkdir(parents=True, exist_ok=True)
        # Saves the full pickled module (not just a state_dict). Loading it
        # back requires the same class definitions to be importable.
        torch.save(model, destination)

    print(model)
    return model

if __name__ == "__main__":
    # Script entry point: build (and print) the prepared ResNet-50 variant.
    prepare50(prepared_path="model/prepared/prepared_resnet50_2.pt")